package helper

import (
	"context"
	"regexp"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus/client/v2/column"
	"github.com/milvus-io/milvus/client/v2/entity"
	client "github.com/milvus-io/milvus/client/v2/milvusclient"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/tests/go_client/base"
	"github.com/milvus-io/milvus/tests/go_client/common"
)

func CreateContext(t *testing.T, timeout time.Duration) context.Context {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	t.Cleanup(func() {
		cancel()
	})
	return ctx
}

// var ArrayFieldType =
func GetAllArrayElementType() []entity.FieldType {
	return []entity.FieldType{
		entity.FieldTypeBool,
		entity.FieldTypeInt8,
		entity.FieldTypeInt16,
		entity.FieldTypeInt32,
		entity.FieldTypeInt64,
		entity.FieldTypeFloat,
		entity.FieldTypeDouble,
		entity.FieldTypeVarChar,
	}
}

func GetAllVectorFieldType() []entity.FieldType {
	return []entity.FieldType{
		entity.FieldTypeBinaryVector,
		entity.FieldTypeFloatVector,
		entity.FieldTypeFloat16Vector,
		entity.FieldTypeBFloat16Vector,
		entity.FieldTypeSparseVector,
	}
}

func GetAllScalarFieldType() []entity.FieldType {
	return []entity.FieldType{
		entity.FieldTypeBool,
		entity.FieldTypeInt8,
		entity.FieldTypeInt16,
		entity.FieldTypeInt32,
		entity.FieldTypeInt64,
		entity.FieldTypeFloat,
		entity.FieldTypeDouble,
		entity.FieldTypeVarChar,
		entity.FieldTypeArray,
		entity.FieldTypeJSON,
		entity.FieldTypeGeometry,
	}
}

func GetAllFieldsType() []entity.FieldType {
	allFieldType := GetAllScalarFieldType()
	allFieldType = append(allFieldType, entity.FieldTypeBinaryVector,
		entity.FieldTypeFloatVector,
		entity.FieldTypeFloat16Vector,
		entity.FieldTypeBFloat16Vector,
		// entity.FieldTypeSparseVector, max vector fields num is 4
	)
	return allFieldType
}

func GetAllNullableFieldType() []entity.FieldType {
	return []entity.FieldType{
		entity.FieldTypeBool,
		entity.FieldTypeInt8,
		entity.FieldTypeInt16,
		entity.FieldTypeInt32,
		entity.FieldTypeInt64,
		entity.FieldTypeFloat,
		entity.FieldTypeDouble,
		entity.FieldTypeVarChar,
		entity.FieldTypeJSON,
		entity.FieldTypeArray,
	}
}

func GetAllDefaultValueFieldType() []entity.FieldType {
	return []entity.FieldType{
		entity.FieldTypeBool,
		entity.FieldTypeInt8,
		entity.FieldTypeInt16,
		entity.FieldTypeInt32,
		entity.FieldTypeInt64,
		entity.FieldTypeFloat,
		entity.FieldTypeDouble,
		entity.FieldTypeVarChar,
	}
}

func GetInvalidPkFieldType() []entity.FieldType {
	nonPkFieldTypes := []entity.FieldType{
		entity.FieldTypeNone,
		entity.FieldTypeBool,
		entity.FieldTypeInt8,
		entity.FieldTypeInt16,
		entity.FieldTypeInt32,
		entity.FieldTypeFloat,
		entity.FieldTypeDouble,
		entity.FieldTypeString,
		entity.FieldTypeJSON,
		entity.FieldTypeGeometry,
		entity.FieldTypeArray,
	}
	return nonPkFieldTypes
}

func GetInvalidPartitionKeyFieldType() []entity.FieldType {
	nonPkFieldTypes := []entity.FieldType{
		entity.FieldTypeBool,
		entity.FieldTypeInt8,
		entity.FieldTypeInt16,
		entity.FieldTypeInt32,
		entity.FieldTypeFloat,
		entity.FieldTypeDouble,
		entity.FieldTypeJSON,
		entity.FieldTypeGeometry,
		entity.FieldTypeArray,
		entity.FieldTypeFloatVector,
	}
	return nonPkFieldTypes
}

func GetAllFieldsName(schema entity.Schema) []string {
	fields := make([]string, 0)
	for _, field := range schema.Fields {
		fields = append(fields, field.Name)
	}
	if schema.EnableDynamicField {
		fields = append(fields, common.DefaultDynamicFieldName)
	}
	return fields
}

// CreateDefaultMilvusClient creates a new client with default configuration
func CreateDefaultMilvusClient(ctx context.Context, t *testing.T) *base.MilvusClient {
	t.Helper()
	mc, err := base.NewMilvusClient(ctx, defaultClientConfig)
	common.CheckErr(t, err, true)

	t.Cleanup(func() {
		mc.Close(ctx)
	})

	return mc
}

// CreateMilvusClient create connect
func CreateMilvusClient(ctx context.Context, t *testing.T, cfg *client.ClientConfig) *base.MilvusClient {
	t.Helper()

	var (
		mc  *base.MilvusClient
		err error
	)
	mc, err = base.NewMilvusClient(ctx, cfg)
	common.CheckErr(t, err, true)

	t.Cleanup(func() {
		mc.Close(ctx)
	})

	return mc
}

// CollectionPrepare ----------------- prepare data --------------------------
type CollectionPrepare struct{}

var (
	CollPrepare CollectionPrepare
	FieldsFact  FieldsFactory
)

func mergeOptions(schema *entity.Schema, opts ...CreateCollectionOpt) client.CreateCollectionOption {
	//
	collectionOption := client.NewCreateCollectionOption(schema.CollectionName, schema)
	tmpOption := &createCollectionOpt{}
	for _, o := range opts {
		o(tmpOption)
	}

	if !common.IsZeroValue(tmpOption.shardNum) {
		collectionOption.WithShardNum(tmpOption.shardNum)
	}

	if !common.IsZeroValue(tmpOption.enabledDynamicSchema) {
		collectionOption.WithDynamicSchema(tmpOption.enabledDynamicSchema)
	}

	if !common.IsZeroValue(tmpOption.properties) {
		for k, v := range tmpOption.properties {
			collectionOption.WithProperty(k, v)
		}
	}

	if !common.IsZeroValue(tmpOption.consistencyLevel) {
		collectionOption.WithConsistencyLevel(*tmpOption.consistencyLevel)
	}

	return collectionOption
}

func (chainTask *CollectionPrepare) CreateCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient,
	cp *CreateCollectionParams, fieldOpts interface{}, schemaOpt *GenSchemaOption, opts ...CreateCollectionOpt,
) (*CollectionPrepare, *entity.Schema) {
	var fields []*entity.Field

	// Handle different parameter types for backward compatibility
	switch v := fieldOpts.(type) {
	case FieldOptions:
		fields = FieldsFact.GenFieldsForCollection(cp.CollectionFieldsType, v)
	case *GenFieldsOption:
		log.Warn("CreateCollection", zap.String("", "*GenFieldsOption has been deprecated, it is recommended to use FieldOptions"))
		// Convert *GenFieldsOption to FieldOptions for compatibility with GenFieldsForCollection
		// First generate fields to get field names, then create FieldOptions
		tempFields := FieldsFact.GenFieldsForCollection(cp.CollectionFieldsType, TNewFieldOptions())
		fieldOpts := TNewFieldOptions()
		for _, field := range tempFields {
			log.Info("CreateCollection", zap.String("name", field.Name))
			fieldOpts = fieldOpts.WithFieldOption(field.Name, v)
		}
		fields = FieldsFact.GenFieldsForCollection(cp.CollectionFieldsType, fieldOpts)
	default:
		log.Fatal("CreateCollection: fieldOpts must be either FieldOptions or *GenFieldsOption")
	}

	schemaOpt.Fields = fields
	if schemaOpt.CollectionName == "" {
		testName := regexp.MustCompile("[^a-zA-Z0-9]").ReplaceAllString(t.Name(), "_")
		schemaOpt.CollectionName = common.GenRandomString(testName, 6)
	}
	schema := GenSchema(schemaOpt)

	createCollectionOption := mergeOptions(schema, opts...)
	err := mc.CreateCollection(ctx, createCollectionOption)
	common.CheckErr(t, err, true)

	t.Cleanup(func() {
		// The collection will be cleanup after the test
		// But some ctx is setted with timeout for only a part of unittest,
		// which will cause the drop collection failed with timeout.
		ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*30)
		defer cancel()

		err := mc.DropCollection(ctx, client.NewDropCollectionOption(schema.CollectionName))
		common.CheckErr(t, err, true)
	})
	return chainTask, schema
}

func (chainTask *CollectionPrepare) InsertData(ctx context.Context, t *testing.T, mc *base.MilvusClient,
	ip *InsertParams, columnOpts interface{},
) (*CollectionPrepare, client.InsertResult) {
	if nil == ip.Schema || ip.Schema.CollectionName == "" {
		log.Fatal("[InsertData] Nil Schema is not expected")
	}

	var columns []column.Column
	var dynamicColumns []column.Column

	// Handle different parameter types for backward compatibility
	switch v := columnOpts.(type) {
	case ColumnOptions:
		columns, dynamicColumns = GenColumnsBasedSchema(ip.Schema, v)
	case *GenDataOption:
		log.Warn("InsertData", zap.String("", "*GenDataOption has been deprecated, it is recommended to use ColumnOptions"))
		// Convert *GenDataOption to ColumnOptions for compatibility
		columnOpts := TNewColumnOptions()
		for _, fieldName := range GetAllFieldsName(*ip.Schema) {
			columnOpts = columnOpts.WithColumnOption(fieldName, v)
		}
		columns, dynamicColumns = GenColumnsBasedSchema(ip.Schema, columnOpts)
	default:
		log.Fatal("InsertData: columnOpts must be either ColumnOptions or *GenDataOption")
	}

	insertOpt := client.NewColumnBasedInsertOption(ip.Schema.CollectionName).WithColumns(columns...).WithColumns(dynamicColumns...)
	if ip.PartitionName != "" {
		insertOpt.WithPartition(ip.PartitionName)
	}
	insertRes, err := mc.Insert(ctx, insertOpt)
	common.CheckErr(t, err, true)

	// Get the number of records from the first column or use a default
	nb := 0
	if len(columns) > 0 {
		nb = columns[0].Len()
	}
	require.Equal(t, nb, insertRes.IDs.Len())
	return chainTask, insertRes
}

func (chainTask *CollectionPrepare) FlushData(ctx context.Context, t *testing.T, mc *base.MilvusClient, collName string) *CollectionPrepare {
	flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
	common.CheckErr(t, err, true)
	err = flushTask.Await(ctx)
	common.CheckErr(t, err, true)
	return chainTask
}

func (chainTask *CollectionPrepare) CreateIndex(ctx context.Context, t *testing.T, mc *base.MilvusClient, ip *IndexParams) *CollectionPrepare {
	if nil == ip.Schema || ip.Schema.CollectionName == "" {
		log.Fatal("[CreateIndex] Empty collection name is not expected")
	}
	collName := ip.Schema.CollectionName
	mFieldIndex := ip.FieldIndexMap

	for _, field := range ip.Schema.Fields {
		if field.DataType >= 100 {
			if idx, ok := mFieldIndex[field.Name]; ok {
				log.Info("CreateIndex", zap.String("indexName", idx.Name()), zap.Any("indexType", idx.IndexType()), zap.Any("indexParams", idx.Params()))
				createIndexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, field.Name, idx))
				common.CheckErr(t, err, true)
				err = createIndexTask.Await(ctx)
				common.CheckErr(t, err, true)
			} else {
				idx := GetDefaultVectorIndex(field.DataType)
				log.Info("CreateIndex", zap.String("indexName", idx.Name()), zap.Any("indexType", idx.IndexType()), zap.Any("indexParams", idx.Params()))
				createIndexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, field.Name, idx))
				common.CheckErr(t, err, true)
				err = createIndexTask.Await(ctx)
				common.CheckErr(t, err, true)
			}
		}
	}
	return chainTask
}

func (chainTask *CollectionPrepare) Load(ctx context.Context, t *testing.T, mc *base.MilvusClient, lp *LoadParams) *CollectionPrepare {
	if lp.CollectionName == "" {
		log.Fatal("[Load] Empty collection name is not expected")
	}
	loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(lp.CollectionName).WithReplica(lp.Replica).WithLoadFields(lp.LoadFields...).
		WithSkipLoadDynamicField(lp.SkipLoadDynamicField).WithResourceGroup(lp.ResourceGroups...).WithRefresh(lp.IsRefresh))
	common.CheckErr(t, err, true)
	err = loadTask.Await(ctx)
	common.CheckErr(t, err, true)
	return chainTask
}
